{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 功能丰富的推荐系统\n",
    "\n",
    "交互数据是用户偏好和兴趣的最基本指示。它在以前引入的模型中起着关键作用。然而，交互数据通常非常稀疏，有时会有噪声。为了解决这个问题，我们可以将诸如项目特征、用户配置文件，甚至是交互发生在哪个上下文中的辅助信息集成到推荐模型中。利用这些特性有助于提出建议，因为这些特性可以有效预测用户的兴趣，尤其是在缺少交互数据时。因此，推荐模型还必须能够处理这些特性，并为模型提供一些内容/上下文感知。为了演示这种类型的推荐模型，我们引入了另一项在线广告推荐点击率（CTR）任务：:cite:`McMahan.Holt.Sculley.ea.2013`，并提供了匿名广告数据。有针对性的广告服务已经引起了广泛的关注，通常被设计成推荐引擎。推荐符合用户个人品味和兴趣的广告对于提高点击率非常重要。\n",
    "\n",
    "\n",
    "数字营销人员使用在线广告向客户展示广告。点击率是一个衡量广告主在每一次印象中收到的广告点击次数的指标，它用公式计算的百分比表示：\n",
    "\n",
    "$$ \\text{CTR} = \\frac{\\#\\text{Clicks}} {\\#\\text{Impressions}} \\times 100 \\% .$$\n",
    "\n",
    "点击率是指示预测算法有效性的重要信号。点击率预测是一项预测网站被点击的可能性的任务。CTR预测模型不仅可用于目标广告系统，也可用于一般项目（如电影、新闻、产品）推荐系统、电子邮件活动，甚至搜索引擎。它还与用户满意度、转化率密切相关，有助于设定活动目标，因为它可以帮助广告商设定现实的期望。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 1,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "\n",
    "import ai.djl.training.dataset.Record;\n",
    "import com.google.gson.Gson;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 2
   },
   "source": [
    "## 在线广告数据集\n",
    "\n",
    "随着互联网和移动技术的长足进步，在线广告已成为互联网行业的一项重要收入来源，并产生了绝大部分收入。展示相关广告或激起用户兴趣的广告非常重要，这样，临时访客就可以转化为付费客户。我们介绍的数据集是一个在线广告数据集。它由34个字段组成，第一列表示目标变量，指示是否单击了广告（1）或未单击广告（0）。所有其他列都是分类特征。这些列可能表示广告id、站点或应用程序id、设备id、时间、用户配置文件等。由于匿名化和隐私问题，这些特征的真正语义尚未公开。\n",
    "\n",
    "以下代码从服务器下载数据集并将其保存到本地数据文件夹中。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "15"
    },
    "origin_pos": 3,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "var url = \"http://d2l-data.s3-accelerate.amazonaws.com/ctr.zip\";\n",
    "ZipUtils.unzip(new URL(url).openStream(), Paths.get(\"./\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 4
   },
   "source": [
    "有一个训练集和一个测试集，分别由15000和3000个样本/行组成。\n",
    "\n",
    "## 数据集封装\n",
    "\n",
    "为了方便数据加载，我们实现了一个`CTRDataset`，它从CSV文件加载广告数据集，并且可以由`DataLoader`使用。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class CTRDataset extends ArrayDataset {\n",
    "\n",
    "    private boolean prepared;\n",
    "    private NDManager manager = Engine.getInstance().newBaseManager();\n",
    "    private List<Long[]> oneHotFeatures;\n",
    "    private List<Float> labelList;\n",
    "\n",
    "    private CTRDataset(Builder builder) {\n",
    "        super(builder);\n",
    "        this.oneHotFeatures = builder.oneHotFeatures;\n",
    "        this.labelList = builder.label;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void prepare(Progress progress) throws IOException {\n",
    "        if (prepared) {\n",
    "            return;\n",
    "        }\n",
    "        data = new NDArray[oneHotFeatures.size()];\n",
    "        labels = new NDArray[labelList.size()];\n",
    "        for (int i = 0; i < data.length; i++) {\n",
    "            data[i] = manager.create(Arrays.stream(oneHotFeatures.get(i)).mapToLong(Long::longValue).toArray());\n",
    "            labels[i] = manager.create(labelList.get(i));\n",
    "        }\n",
    "        prepared = true;\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * {@inheritDoc}\n",
    "     */\n",
    "    @Override\n",
    "    public Record get(NDManager manager, long index) {\n",
    "        NDList datum = new NDList();\n",
    "        NDList label = new NDList();\n",
    "\n",
    "        datum.add(data[(int) index]);\n",
    "        if (labels != null) {\n",
    "            label.add(labels[(int) index]);\n",
    "        }\n",
    "        datum.attach(manager);\n",
    "        label.attach(manager);\n",
    "        return new Record(datum, label);\n",
    "    }\n",
    "\n",
    "\n",
    "    public static Builder builder() {\n",
    "        return new Builder();\n",
    "    }\n",
    "\n",
    "    public static final class Builder extends BaseBuilder<Builder> {\n",
    "\n",
    "        private long numFeatures;\n",
    "        private long featureThreshold;\n",
    "        private String fileName;\n",
    "        // feature id, category String, category code\n",
    "        private Map<Long, Map<String, Long>> featureMap = new ConcurrentHashMap<>();\n",
    "        // feature id, category String, category count\n",
    "        private Map<Long, Map<String, Long>> featureCount = new ConcurrentHashMap<>();\n",
    "        private Map<Long, Long> defaultValues = new ConcurrentHashMap<>();\n",
    "        private List<String[]> features = new ArrayList<>();\n",
    "        private List<Float> label = new ArrayList<>();\n",
    "        private Long[] fieldDim;\n",
    "        private Long[] offset;\n",
    "        private List<Long[]> oneHotFeatures = new ArrayList<>();\n",
    "        private String outputDir;\n",
    "\n",
    "        Builder() {\n",
    "        }\n",
    "\n",
    "        @Override\n",
    "        protected Builder self() {\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder setFileName(String fileName) {\n",
    "            this.fileName = fileName;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder optNumFeatures(long numFeatures) {\n",
    "            this.numFeatures = numFeatures;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder optFeatureThreshold(long featureThreshold) {\n",
    "            this.featureThreshold = featureThreshold;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder optMapOutputDir(String outputDir) {\n",
    "            this.outputDir = outputDir;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        CTRDataset build() throws IOException {\n",
    "\n",
    "            try (BufferedReader reader = Files.newBufferedReader(Paths.get(this.fileName))) {\n",
    "                String line;\n",
    "                while ((line = reader.readLine()) != null) {\n",
    "                    String[] record = line.trim().split(\"\\t\");\n",
    "                    if (record.length != this.numFeatures + 1) {\n",
    "                        continue;\n",
    "                    }\n",
    "                    label.add(Float.parseFloat(record[0]));\n",
    "                    for (int i = 1; i < numFeatures + 1; i++) {\n",
    "                        Map<String, Long> count = featureCount.computeIfAbsent((long) i, k -> new ConcurrentHashMap<>());\n",
    "                        // 字符串的增量计数\n",
    "                        count.merge(record[i], 1L, Long::sum);\n",
    "                    }\n",
    "                    features.add(Arrays.copyOfRange(record, 1, record.length));\n",
    "                }\n",
    "            }\n",
    "            fieldDim = new Long[(int) numFeatures];\n",
    "            offset = new Long[(int) numFeatures];\n",
    "            // 减少 class 频率\n",
    "            for (long i = 1L; i < numFeatures + 1; i++) {\n",
    "                featureCount.get(i).values().removeIf(value -> value < featureThreshold);\n",
    "                Map<String, Long> reducedFeatures = featureCount.get(i);\n",
    "                Map<String, Long> featureIndex = new ConcurrentHashMap<>();\n",
    "                long index = 0;\n",
    "                for (String feature : reducedFeatures.keySet()) {\n",
    "                    featureIndex.put(feature, index);\n",
    "                    index++;\n",
    "                }\n",
    "                featureMap.put(i, featureIndex);\n",
    "                defaultValues.put(i, (long) featureIndex.size());\n",
    "                fieldDim[(int) i - 1] = (long) featureIndex.size() - 1;\n",
    "            }\n",
    "            long sum = 0;\n",
    "            for (int i = 0; i < fieldDim.length; i++) {\n",
    "                offset[i] = sum;\n",
    "                sum += fieldDim[i];\n",
    "            }\n",
    "\n",
    "            for (String[] feature : features) {\n",
    "                Long[] oneHot = new Long[feature.length];\n",
    "                for (int i = 0; i < oneHot.length; i++) {\n",
    "                    oneHot[i] = featureMap.get((long) i + 1).getOrDefault(feature[i], defaultValues.get((long) i + 1)) + offset[i];\n",
    "                }\n",
    "                oneHotFeatures.add(oneHot);\n",
    "            }\n",
    "            // 保存特征映射和默认值以进行推断\n",
    "            if (outputDir != null) {\n",
    "                saveMap(featureMap, outputDir, \"feature_map.json\");\n",
    "                saveMap(defaultValues, outputDir, \"defaults.json\");\n",
    "            }\n",
    "\n",
    "            return new CTRDataset(this);\n",
    "        }\n",
    "\n",
    "        private void saveMap(Map map, String outputDir, String fileName) throws IOException {\n",
    "            Gson gson = new Gson();\n",
    "            FileWriter writer = new FileWriter(outputDir + \"/\" + fileName);\n",
    "            gson.toJson(map, writer);\n",
    "            writer.flush();\n",
    "            writer.close();\n",
    "        }\n",
    "\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "下面的示例加载训练数据并打印出第一条记录。我们还需要保存特征映射和默认值以进行推断。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "16"
    },
    "origin_pos": 7,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "CTRDataset data = CTRDataset.builder()\n",
    "                .optFeatureThreshold(4)\n",
    "                .optNumFeatures(34)\n",
    "                .setFileName(\"./ctr/train.csv\")\n",
    "                .optMapOutputDir(\"./\")\n",
    "                .setSampling(1, true)\n",
    "                .build();\n",
    "data.prepare();\n",
    "NDManager manager = NDManager.newBaseManager();\n",
    "Record record = data.get(manager, 0);\n",
    "System.out.println(record.getData().singletonOrThrow());\n",
    "System.out.println(record.getLabels().singletonOrThrow());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 8
   },
   "source": [
    "可以看出，所有34个字段都是分类特征。每个值表示对应条目的一个热索引。标签$0$表示未单击该标签。 该`CTRDataset`数据集还可用于加载其他数据集，如Criteo图片广告挑战赛数据集 [Dataset](https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/) 和Avazu点击率预测数据集 [Dataset](https://www.kaggle.com/c/avazu-ctr-prediction).  \n",
    "\n",
    "## 总结\n",
    "* 点击率是衡量广告系统和推荐系统有效性的重要指标。\n",
    "* 点击率预测通常转化为二进制分类问题。目标是根据给定的特性预测是否会单击广告/项目。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 能否使用提供的`CTRDataset`加载Criteo和Avazu数据集。值得注意的是，Criteo数据集包含实值特征，因此您可能需要稍微修改代码。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
